from typing import Dict, List
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, Qwen2ForSequenceClassification, LlamaForSequenceClassification
import torch
import torch.nn as nn
import numpy as np
import os



class CustomClassifier(nn.Module):
    """Three parallel bias-free MLP heads applied to the same input.

    The heads are stored as ``scores``, ``weights`` and ``gatings``; each maps
    ``input_dim`` features to ``output_dim`` features. ``forward`` returns
    their outputs concatenated along the last dimension, in that order, so the
    result has ``3 * output_dim`` features.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        def build_narrow_head():
            # Layout shared by the ``weights`` and ``gatings`` heads:
            # four bias-free linear layers, all hidden widths 1024.
            return nn.Sequential(
                nn.Linear(input_dim, 1024, bias=False),
                nn.SiLU(),
                nn.Dropout(0.1),
                nn.Linear(1024, 1024, bias=False),
                nn.SiLU(),
                nn.Linear(1024, 1024, bias=False),
                nn.SiLU(),
                nn.Dropout(0.1),
                nn.Linear(1024, output_dim, bias=False),
            )

        # The ``scores`` head is one layer deeper and starts at width 2048.
        score_layers = [
            nn.Linear(input_dim, 2048, bias=False),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 1024, bias=False),
            nn.SiLU(),
            nn.Linear(1024, 1024, bias=False),
            nn.SiLU(),
            nn.Linear(1024, 1024, bias=False),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(1024, output_dim, bias=False),
        ]
        # Heads are created in the order scores -> weights -> gatings so that
        # parameter initialisation draws from the RNG in a fixed sequence.
        self.scores = nn.Sequential(*score_layers)
        self.weights = build_narrow_head()
        self.gatings = build_narrow_head()

    def forward(self, x):
        """Run all three heads on ``x`` and concatenate along the last dim."""
        per_head = [head(x) for head in (self.scores, self.weights, self.gatings)]
        return torch.cat(per_head, dim=-1)

class NewQwen2ForSequenceClassification(LlamaForSequenceClassification):
    """Sequence-classification model whose ``score`` head is a ``CustomClassifier``.

    NOTE(review): the class name says Qwen2 but the base class is
    ``LlamaForSequenceClassification`` — confirm this pairing is intended.
    """

    def __init__(self, config):
        super(NewQwen2ForSequenceClassification, self).__init__(config)
        # Each CustomClassifier head emits num_labels // 2 features; the
        # classifier concatenates three such heads in its forward pass.
        per_head_dim = config.num_labels // 2
        # Assigned after the base initializer has run, so this is the value
        # ``self.score`` ends up with.
        self.score = CustomClassifier(config.hidden_size, per_head_dim)

class ArmoRMPipeline:
    """Inference wrapper that scores chat conversations with a reward model.

    The model is expected to return logits that split into three equal-width
    groups along the last dimension: per-objective scores, per-objective
    weights, and per-objective gatings (see ``_combine_heads``).
    """

    def __init__(self, model_id, device_map="auto", torch_dtype=torch.bfloat16, truncation=True, trust_remote_code=False, max_length=8192):
        """Load the model and tokenizer identified by ``model_id``.

        Args:
            model_id: Passed to ``from_pretrained`` for both the model and the
                tokenizer.
            device_map: Forwarded to the model's ``from_pretrained``.
            torch_dtype: Forwarded to the model's ``from_pretrained``.
            truncation: Passed to the tokenizer as ``truncation`` on every call.
            trust_remote_code: Forwarded to the model's ``from_pretrained``
                (not to the tokenizer).
            max_length: Passed to the tokenizer as ``max_length`` on every call.
        """
        self.model = NewQwen2ForSequenceClassification.from_pretrained(
            model_id,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            use_fast=True,
        )
        self.truncation = truncation
        self.device = self.model.device
        self.max_length = max_length

    def __call__(self, messages: List[List[Dict[str, str]]]) -> torch.Tensor:
        """Score up to two conversations that share the same first message.

        A prompt-only conversation, built from the ``content`` of the first
        message of ``messages[0]``, is prepended to the batch. All
        conversations are rendered with the tokenizer's chat template,
        tokenized together with padding, and run through the model in a
        single forward pass. The prompt-only row is used only to decide which
        objectives are active; its own score is dropped from the result.

        Args:
            messages: A sequence of at most two conversations, each a list of
                OpenAI-style ``{'role': ..., 'content': ...}`` dicts.

        Returns:
            A 1-D ``float32`` tensor with one score per input conversation, in
            input order. Each score is a softmax-weighted average of sigmoids.

        Raises:
            AssertionError: If more than two conversations are given.
            IndexError: If ``messages`` or its first conversation is empty.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # the exception type and message are kept for existing callers.
        if len(messages) > 2:
            raise AssertionError('Too many messages for the model')

        prompt_only = [{'role': 'user', 'content': messages[0][0]['content']}]
        conversations = [prompt_only, *messages]

        rendered = self.tokenizer.apply_chat_template(
            conversations,
            tokenize=False,
            add_generation_prompt=False
        )
        encoded = self.tokenizer(
            rendered,
            return_tensors="pt",
            padding=True,
            truncation=self.truncation,
            max_length=self.max_length
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**encoded).logits
            return self._combine_heads(logits)

    @staticmethod
    def _combine_heads(logits: torch.Tensor) -> torch.Tensor:
        """Reduce ``(rows, 3 * k)`` logits to one score per row after the first.

        The last dimension is split into ``scores``, ``weights`` and
        ``gatings`` (``k`` columns each). Objectives whose weight in row 0 is
        positive are active for every row; each row's gatings are softmaxed
        over the active objectives only and used to average the sigmoid of
        that row's scores.
        """
        scores, weights, gatings = torch.split(logits, logits.size(-1) // 3, dim=-1)

        # Row 0 (the prompt-only conversation) selects the active objectives
        # for the whole batch; the (k,) mask broadcasts across rows.
        active = weights[0] > 0.0

        # Inactive objectives get -inf so softmax assigns them zero mass.
        # NOTE: if no objective is active, every entry is -inf and the softmax
        # (and therefore the score) is NaN.
        gate = gatings.masked_fill(~active, float('-inf')).softmax(dim=-1)

        overall = (gate * scores.sigmoid()).sum(dim=-1)

        # Drop the prompt-only row and return an independent float32 copy.
        return overall[1:].detach().to(dtype=torch.float32, copy=True)